# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from dataclasses import dataclass
from typing import Any, Dict, Mapping

from promptflow.contracts.run_info import FlowRunInfo, RunInfo


@dataclass
class LineResult:
    """The result of processing a single line of a flow."""

    output: Mapping[str, Any]  # The output of the line.
    # The node output values to be used as aggregation inputs, if no aggregation node, it will be empty.
    aggregation_inputs: Mapping[str, Any]
    run_info: FlowRunInfo  # The run info of the line.
    node_run_infos: Mapping[str, RunInfo]  # The run info of the nodes in the line, keyed by node name.

    @staticmethod
    def deserialize(data: dict) -> "LineResult":
        """Deserialize the LineResult from a dict.

        Args:
            data: A dict that may contain the keys ``output``, ``aggregation_inputs``,
                ``run_info`` and ``node_run_infos``. ``run_info`` is passed to
                ``FlowRunInfo.deserialize`` and every value of ``node_run_infos`` is
                passed to ``RunInfo.deserialize``.

        Returns:
            The reconstructed LineResult. A missing ``output`` becomes ``None``. A missing
            or ``None`` ``aggregation_inputs`` / ``node_run_infos`` becomes an empty dict.

        Note:
            A missing ``run_info`` is forwarded as ``None`` to ``FlowRunInfo.deserialize``.
            How that is handled depends on that method (not visible here).
        """
        aggregation_inputs = data.get("aggregation_inputs")
        # `dict.get(key, default)` only falls back when the key is absent; an explicit None
        # would otherwise break the "empty when there is no aggregation node" contract.
        if aggregation_inputs is None:
            aggregation_inputs = {}
        # Same reasoning: an explicit None here used to crash on `.items()`.
        node_run_infos = data.get("node_run_infos") or {}
        return LineResult(
            output=data.get("output"),
            aggregation_inputs=aggregation_inputs,
            run_info=FlowRunInfo.deserialize(data.get("run_info")),
            node_run_infos={k: RunInfo.deserialize(v) for k, v in node_run_infos.items()},
        )


@dataclass
class AggregationResult:
    """The result when running aggregation nodes in the flow."""

    output: Mapping[str, Any]  # The output of the aggregation nodes in the flow.
    metrics: Dict[str, Any]  # The metrics generated by the aggregation.
    node_run_infos: Mapping[str, RunInfo]  # The run info of the aggregation nodes, keyed by node name.

    @staticmethod
    def deserialize(data: dict) -> "AggregationResult":
        """Deserialize the AggregationResult from a dict.

        Args:
            data: A dict that may contain the keys ``output``, ``metrics`` and
                ``node_run_infos``. Every value of ``node_run_infos`` is passed to
                ``RunInfo.deserialize``.

        Returns:
            The reconstructed AggregationResult. A missing ``output`` or ``metrics``
            becomes ``None``. A missing or ``None`` ``node_run_infos`` becomes an empty dict.
        """
        # `dict.get(key, {})` does not cover an explicit None value, which would crash on
        # `.items()`; normalise both cases to an empty mapping.
        node_run_infos = data.get("node_run_infos") or {}
        return AggregationResult(
            # NOTE: these stay None when absent, despite the non-Optional annotations,
            # to remain backward compatible with existing callers.
            output=data.get("output"),
            metrics=data.get("metrics"),
            node_run_infos={k: RunInfo.deserialize(v) for k, v in node_run_infos.items()},
        )
